{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6baff84d",
   "metadata": {},
   "source": [
    "Copyright (c) 2021, salesforce.com, inc.\\\n",
    "All rights reserved.\\\n",
    "SPDX-License-Identifier: BSD-3-Clause\\\n",
    "For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c618d07b",
   "metadata": {},
   "source": [
    "### Colab"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eba754b9",
   "metadata": {},
   "source": [
    "Try this notebook on [Colab](http://colab.research.google.com/github/salesforce/warp-drive/blob/master/tutorials/tutorial-5-training_with_warp_drive.ipynb)!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda63add",
   "metadata": {},
   "source": [
    "## ⚠️ PLEASE NOTE:\n",
    "This notebook runs on a GPU runtime.\\\n",
    "If running on Colab, choose Runtime > Change runtime type from the menu, then select 'GPU' in the dropdown."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab2fe5fd",
   "metadata": {},
   "source": [
    "# Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cebd3ce",
   "metadata": {},
   "source": [
    "In this tutorial, we describe how to\n",
    "- Use the WarpDrive framework to perform end-to-end training of multi-agent reinforcement learning (RL) agents.\n",
    "- Visualize the behavior using the trained policies.\n",
    "\n",
    "If you are not yet familiar with WarpDrive, please see our other tutorials:\n",
    "- [WarpDrive basics](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-1-warp_drive_basics.ipynb)\n",
    "- [WarpDrive sampler](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-2-warp_drive_sampler.ipynb)\n",
    "- [WarpDrive reset and log controller](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-3-warp_drive_reset_and_log.ipynb)\n",
    "\n",
    "Please also see our [tutorial](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-4-create_custom_environments.ipynb) on creating your own RL environment in CUDA C. Once you have your own environment in CUDA C, this tutorial explains how to integrate it with the WarpDrive framework to perform training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e0b31c7",
   "metadata": {},
   "source": [
    "## Dependencies"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eccca4cb",
   "metadata": {},
   "source": [
    "You can install the WarpDrive package in either of two ways:\n",
    "\n",
    "- using the pip package manager, or\n",
    "- by cloning the warp-drive repository and installing it from source (we use this option when running on Colab)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8608de0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "IN_COLAB = 'google.colab' in sys.modules\n",
    "\n",
    "if IN_COLAB:\n",
    "    ! git clone https://github.com/salesforce/warp-drive.git\n",
    "    %cd warp-drive\n",
    "    ! pip install -e .\n",
    "    %cd tutorials\n",
    "else:\n",
    "    ! pip install rl_warp_drive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7dca5bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from warp_drive.env_wrapper import EnvWrapper\n",
    "from warp_drive.training.models.fully_connected import FullyConnected\n",
    "\n",
    "from example_envs.tag_continuous.tag_continuous import TagContinuous\n",
    "from utils.generate_rollout_animation import generate_env_rollout_animation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fede7c9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gym.spaces import Discrete, MultiDiscrete\n",
    "from IPython.display import HTML\n",
    "import json\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64587609",
   "metadata": {},
   "source": [
    "## Training the tag-continuous environment with WarpDrive"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d8d1c42",
   "metadata": {},
   "source": [
    "For your convenience, there is an end-to-end RL training script at `warp_drive/training/example_training_script.py`. Currently, it supports training both the discrete and the continuous versions of Tag."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93b97fca",
   "metadata": {},
   "source": [
    "In order to run the training for these environments, we first need to configure the *run config*: the set of environment, training, and model parameters. \n",
    "\n",
    "The run configs for each of the environments are listed in `warp_drive/training/run_configs`, and a sample set of good configs for the **tag-continuous** environment is shown below. \n",
    "\n",
    "In this tutorial, we'll use $5$ taggers and $100$ runners in a $20 \\times 20$ square grid. The taggers and runners have the same skill level, i.e., the runners can move just as fast as the taggers. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f275123",
   "metadata": {},
   "source": [
    "```yaml\n",
    "# YAML configuration for the tag continuous environment\n",
    "name: \"tag_continuous\"\n",
    "\n",
    "# Environment settings\n",
    "env:\n",
    "    num_taggers: 5\n",
    "    num_runners: 100\n",
    "    grid_length: 20\n",
    "    episode_length: 500\n",
    "    max_acceleration: 0.1\n",
    "    min_acceleration: -0.1\n",
    "    max_turn: 2.35  # 3*pi/4 radians\n",
    "    min_turn: -2.35  # -3*pi/4 radians\n",
    "    num_acceleration_levels: 20\n",
    "    num_turn_levels: 20\n",
    "    skill_level_runner: 1\n",
    "    skill_level_tagger: 1\n",
    "    seed: 274880\n",
    "    use_full_observation: False\n",
    "    runner_exits_game_after_tagged: True\n",
    "    num_other_agents_observed: 10\n",
    "    tag_reward_for_tagger: 10.0\n",
    "    tag_penalty_for_runner: -10.0\n",
    "    step_penalty_for_tagger: -0.00\n",
    "    step_reward_for_runner: 0.00\n",
    "    edge_hit_penalty: -0.0\n",
    "    end_of_game_reward_for_runner: 1.0\n",
    "    tagging_distance: 0.02\n",
    "    \n",
    "# Trainer settings\n",
    "trainer:\n",
    "    num_envs: 1  # Number of environment replicas\n",
    "    num_episodes: 1000000000  # Number of episodes to run the training for\n",
    "    train_batch_size: 100  # total batch size used for training per iteration (across all the environments)\n",
    "    algorithm: \"A2C\"  # trainer algorithm\n",
    "    vf_loss_coeff: 1  # loss coefficient for the value function loss\n",
    "    entropy_coeff: 0.05  # coefficient for the entropy component of the loss\n",
    "    clip_grad_norm: True  # flag indicating whether to clip the gradient norm or not\n",
    "    max_grad_norm: 0.5  # when clip_grad_norm is True, the clip level\n",
    "    normalize_advantage: False  # flag indicating whether to normalize advantage or not\n",
    "    normalize_return: False  # flag indicating whether to normalize return or not\n",
    "\n",
    "# Policy network settings\n",
    "policy:  # list all the policies below\n",
    "    runner:\n",
    "        to_train: True\n",
    "        name: \"fully_connected\"\n",
    "        gamma: 0.98  # discount rate gamma\n",
    "        lr: 0.005  # learning rate\n",
    "        model:        \n",
    "            fc_dims: [256, 256]  # dimension(s) of the fully connected layers as a list\n",
    "            model_ckpt_filepath: \"\"\n",
    "    tagger:\n",
    "        to_train: True\n",
    "        name: \"fully_connected\"\n",
    "        gamma: 0.98\n",
    "        lr: 0.002\n",
    "        model:\n",
    "            fc_dims: [256, 256]\n",
    "            model_ckpt_filepath: \"\"\n",
    "            \n",
    "# Checkpoint saving setting\n",
    "saving:\n",
    "    print_metrics_freq: 100  # How often (in iterations) to print the metrics\n",
    "    save_model_params_freq: 5000  # How often (in iterations) to save the model parameters\n",
    "    basedir: \"/tmp\"  # base folder used for saving\n",
    "    tag: \"800runners_5taggers_bs100\"\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dad27901",
   "metadata": {},
   "source": [
    "Next, we need to specify a mapping from each policy to the indices of the agents trained with that policy. This is set in `warp_drive/training/example_training_script.py`. Here, we have separate tagger and runner policies, and we map each to the corresponding agents:\n",
    "\n",
    "```python\n",
    "    policy_tag_to_agent_id_map = {\n",
    "        \"tagger\": list(envObj.env.taggers),\n",
    "        \"runner\": list(envObj.env.runners),\n",
    "    }\n",
    "```\n",
    "\n",
    "\n",
    "Note that if you wish to use a single policy across all the agents, or a different set of policies, you will need to update both the run configuration and the `policy_tag_to_agent_id_map`.\n",
    "\n",
    "For example, to use a shared policy across all agents (say `shared_policy`), you can set the policy section of the run configuration to\n",
    "```python\n",
    "    \"policy\": {\n",
    "        \"shared_policy\": {\n",
    "            \"to_train\": True,\n",
    "            \"name\": \"fully_connected\",\n",
    "            \"gamma\": 0.98,\n",
    "            \"lr\": 0.002,\n",
    "            \"model\": {\n",
    "                \"fc_dims\": [256, 256],\n",
    "                \"model_ckpt_filepath\": \"\",\n",
    "            },\n",
    "        },\n",
    "    },\n",
    "```\n",
    "and also set all the agent ids to use this shared policy\n",
    "```python\n",
    "    policy_tag_to_agent_id_map = {\n",
    "        \"shared_policy\": np.arange(envObj.env.num_agents),\n",
    "    }\n",
    "```\n",
    "\n",
    "**Note: make sure the `policy` keys and the `policy_tag_to_agent_id_map` keys are identical.**"
   ]
  },
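  {
   "cell_type": "markdown",
   "id": "7c1e5a90",
   "metadata": {},
   "source": [
    "The key-matching requirement in the note above can be checked before training starts. The following is a minimal sketch and not part of WarpDrive: the helper name `check_policy_keys` and the toy inputs are hypothetical and only illustrate the idea.\n",
    "\n",
    "```python\n",
    "def check_policy_keys(run_config, policy_tag_to_agent_id_map):\n",
    "    \"\"\"Raise if the policy names and the mapping keys differ (hypothetical helper).\"\"\"\n",
    "    config_keys = set(run_config[\"policy\"])\n",
    "    map_keys = set(policy_tag_to_agent_id_map)\n",
    "    if config_keys != map_keys:\n",
    "        raise ValueError(\n",
    "            f\"Only in run config: {sorted(config_keys - map_keys)}; \"\n",
    "            f\"only in mapping: {sorted(map_keys - config_keys)}\"\n",
    "        )\n",
    "\n",
    "\n",
    "# Toy inputs for illustration only.\n",
    "toy_run_config = {\"policy\": {\"tagger\": {}, \"runner\": {}}}\n",
    "toy_map = {\"tagger\": [0, 1], \"runner\": [2, 3, 4]}\n",
    "check_policy_keys(toy_run_config, toy_map)  # matching keys, so nothing is raised\n",
    "```\n",
    "\n",
    "Calling such a check right after building the mapping surfaces a misspelled policy name before any training starts."
   ]
  },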
  {
   "cell_type": "markdown",
   "id": "88c9cdb6",
   "metadata": {},
   "source": [
    "Once the run configuration and the policy to agent id mapping are set, you can start training with\n",
    "```shell\n",
    "python warp_drive/training/example_training_script.py --env <ENV-NAME>\n",
    "```\n",
    "where `<ENV-NAME>` can be `tag_gridworld` or `tag_continuous` (or any new env that you build). And that's it!\n",
    "\n",
    "The training script performs the following steps, in order:\n",
    "1. Creates the pertinent environment object (with the `use_cuda` flag set to True).\n",
    "2. Creates and pushes observation, action, reward and done placeholder data arrays to the device.\n",
    "3. Creates the trainer object using the environment object, the run configuration, and policy to agent id mapping.\n",
    "4. Invokes `trainer.train()`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df7cd3dd",
   "metadata": {},
   "source": [
    "## Visualizing the trained policies"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d50295e4",
   "metadata": {},
   "source": [
    "In the run config, the `save_model_params_freq` parameter sets how often (in iterations) the model checkpoints are saved. With the model checkpoints, we can initialize the neural network weights and generate a full episode rollout.\n",
    "\n",
    "We can find an example run config and the trained tagger and runner policy model weights (after about 20M steps) in the `assets/tag_continuous_training/` folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e17f926",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the run config.\n",
    "with open(\"assets/tag_continuous_training/run_config.json\") as f:\n",
    "    run_config = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed028cec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the environment object.\n",
    "env_wrapper = EnvWrapper(TagContinuous(**run_config['env']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e729caa9",
   "metadata": {},
   "source": [
    "The taggers (runners) use a shared tagger (runner) policy model. The `policy_tag_to_agent_id_map` describes this mapping."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c6becf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the policy tag to agent id mapping.\n",
    "policy_tag_to_agent_id_map = {\n",
    "    \"tagger\": list(env_wrapper.env.taggers),\n",
    "    \"runner\": list(env_wrapper.env.runners),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "893662d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step through the environment.\n",
    "# The environment(s) store and update the rollout data internally in env.global_state.\n",
    "\n",
    "\n",
    "def generate_rollout_inplace(env_wrapper, run_config, load_model_weights=False):\n",
    "    assert env_wrapper is not None\n",
    "    assert run_config is not None\n",
    "    \n",
    "    obs = env_wrapper.reset_all_envs()\n",
    "    action_space = env_wrapper.env.action_space[0]\n",
    "        \n",
    "    # Instantiate the policy models.\n",
    "    policy_models = {}\n",
    "\n",
    "    for policy in policy_tag_to_agent_id_map:\n",
    "        policy_config = run_config[\"policy\"][policy]\n",
    "        if policy_config[\"name\"] == \"fully_connected\":\n",
    "            policy_models[policy] = FullyConnected(\n",
    "                env=env_wrapper,\n",
    "                model_config=policy_config[\"model\"],\n",
    "                policy=policy,\n",
    "                policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,\n",
    "            )\n",
    "        else:\n",
    "            raise NotImplementedError\n",
    "    \n",
    "    if load_model_weights:\n",
    "        print(\"Loading saved weights into the policy models...\")\n",
    "        for policy in policy_models:\n",
    "            state_dict_filepath = f\"assets/tag_continuous_training/{policy}_after_training.state_dict\"\n",
    "            policy_models[policy].load_state_dict(torch.load(state_dict_filepath))\n",
    "            print(f\"Loaded ckpt {state_dict_filepath} for {policy} policy model.\")\n",
    "\n",
    "    for t in range(env_wrapper.env.episode_length):\n",
    "        # np.stack needs a sequence, so convert the dict view to a list first.\n",
    "        stacked_obs = np.stack(list(obs.values())).astype(np.float32)\n",
    "        \n",
    "        \n",
    "        # Create dict to collect the actions for all agents.\n",
    "        if isinstance(action_space, Discrete):\n",
    "            actions = {agent_id: 0 for agent_id in range(env_wrapper.env.num_agents)}\n",
    "        elif isinstance(action_space, MultiDiscrete):\n",
    "            actions = {agent_id: [0, 0] for agent_id in range(env_wrapper.env.num_agents)}\n",
    "        else:\n",
    "            raise NotImplementedError\n",
    "\n",
    "        \n",
    "        # Sample actions for all agents.\n",
    "        for policy in policy_models:\n",
    "            agent_ids = policy_tag_to_agent_id_map[policy]\n",
    "            probabilities, vals = policy_models[policy](\n",
    "                obs=torch.from_numpy(stacked_obs[agent_ids])\n",
    "            )\n",
    "            if isinstance(action_space, Discrete):\n",
    "                for idx, probs in enumerate(probabilities):\n",
    "                    sampled_actions = torch.multinomial(probs, num_samples=1)\n",
    "                    for sample_action_idx, action in enumerate(sampled_actions):\n",
    "                        actions[agent_ids[sample_action_idx]] = action.numpy()[0]\n",
    "\n",
    "            elif isinstance(action_space, MultiDiscrete):\n",
    "                for idx, probs in enumerate(probabilities):\n",
    "                    sampled_actions = torch.multinomial(probs, num_samples=1)\n",
    "                    for sample_action_idx, action in enumerate(sampled_actions):\n",
    "                        actions[agent_ids[sample_action_idx]][idx] = action.numpy()[0]\n",
    "            else:\n",
    "                raise NotImplementedError\n",
    "        \n",
    "        \n",
    "        # Execute actions in the environment.\n",
    "        obs, rew, done, info = env_wrapper.step(actions)\n",
    "        \n",
    "        if done[\"__all__\"]:\n",
    "            break"
   ]
  },
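  {
   "cell_type": "markdown",
   "id": "b84d2f16",
   "metadata": {},
   "source": [
    "The sampling loop in `generate_rollout_inplace` treats the policy output as one probability table per action dimension, with one row per agent. The sketch below mimics that indexing with plain Python lists, using `random.choices` as a stand-in for `torch.multinomial`; the probability values and agent ids are made up for illustration.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "random.seed(0)\n",
    "\n",
    "# Hypothetical policy output for 2 agents and a 2-dimensional action\n",
    "# (e.g., acceleration and turn): one table per dimension, one row per agent.\n",
    "probabilities = [\n",
    "    [[0.1, 0.9], [0.5, 0.5]],            # dimension 0, 2 levels\n",
    "    [[0.2, 0.3, 0.5], [1.0, 0.0, 0.0]],  # dimension 1, 3 levels\n",
    "]\n",
    "agent_ids = [3, 7]\n",
    "\n",
    "actions = {agent_id: [0, 0] for agent_id in agent_ids}\n",
    "for idx, probs in enumerate(probabilities):\n",
    "    for row, agent_probs in enumerate(probs):\n",
    "        levels = range(len(agent_probs))\n",
    "        # Draw one level for this agent along action dimension idx.\n",
    "        actions[agent_ids[row]][idx] = random.choices(levels, weights=agent_probs)[0]\n",
    "\n",
    "print(actions)\n",
    "```\n",
    "\n",
    "Note that the row index within each table is mapped back to the global agent id through `agent_ids`, which is the role `policy_tag_to_agent_id_map[policy]` plays in the function above."
   ]
  },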
  {
   "cell_type": "markdown",
   "id": "4ca1e7e5",
   "metadata": {},
   "source": [
    "## Visualize the environment before training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6171d132",
   "metadata": {},
   "outputs": [],
   "source": [
    "generate_rollout_inplace(env_wrapper, run_config)\n",
    "# Visualize the env at t=0\n",
    "anim = generate_env_rollout_animation(env_wrapper.env, i_start=1, fps=50, fig_width=6, fig_height=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e91a1d31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now, visualize the entire episode roll-out\n",
    "HTML(anim.to_html5_video())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3af2085",
   "metadata": {},
   "source": [
    "In the visualization above, the large purple dots represent the taggers, while the smaller blue dots represent the runners. Before training, the runners and taggers move around randomly, so only some of the runners get tagged, purely by chance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a42d8e58",
   "metadata": {},
   "source": [
    "## Visualize the environment after training (for about 20M steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47045f4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "generate_rollout_inplace(env_wrapper, run_config, load_model_weights=True)\n",
    "# Visualize the env at t=0\n",
    "anim = generate_env_rollout_animation(env_wrapper.env, i_start=1, fps=50, fig_width=6, fig_height=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52fb876b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now, visualize the entire episode roll-out\n",
    "HTML(anim.to_html5_video())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b5999b8",
   "metadata": {},
   "source": [
    "After training, the runners learn to run away from the taggers, and the taggers learn to chase them; there are some instances where we see that taggers also team up to chase and tag the runners. Overall, about 80% of the runners are caught now."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20e75f6f",
   "metadata": {},
   "source": [
    "# Learn More and Explore our Tutorials!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4eb1a7b8",
   "metadata": {},
   "source": [
    "You've now seen the entire end-to-end multi-agent RL pipeline!\n",
    "\n",
    "For your reference, all our tutorials are here:\n",
    "- [A simple end-to-end RL training example](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/simple-end-to-end-example.ipynb)\n",
    "- [WarpDrive basics](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-1-warp_drive_basics.ipynb)\n",
    "- [WarpDrive sampler](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-2-warp_drive_sampler.ipynb)\n",
    "- [WarpDrive reset and log](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-3-warp_drive_reset_and_log.ipynb)\n",
    "- [Creating custom environments](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-4-create_custom_environments.ipynb)\n",
    "- [Training with WarpDrive](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-5-training_with_warp_drive.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
